{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RLHF Implementation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install and import packages and libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %pip install --upgrade pip\n",
    "# %pip install transformers trl wandb\n",
    "#%pip install plotly -qqq"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm\n",
    "import pandas as pd\n",
    "\n",
    "tqdm.pandas()\n",
    "\n",
    "from trl import PPOTrainer, PPOConfig, AutoModelForCausalLMWithValueHead,RewardTrainer\n",
    "from trl.core import LengthSampler\n",
    "import random\n",
    "from operator import itemgetter\n",
    "from datasets import Dataset, load_dataset\n",
    "from transformers import (\n",
    "    AutoModelForSequenceClassification,\n",
    "    AutoTokenizer,\n",
    "    TrainingArguments,\n",
    "    pipeline,\n",
    ")\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reward Model Training"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetch and process dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>instruction</th>\n",
       "      <th>chosen_response</th>\n",
       "      <th>rejected_response</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Can Instana use OpenTelemetry trace data?​</td>\n",
       "      <td>Yes, Instana can use OpenTelemetry trace data....</td>\n",
       "      <td>Answer: Yes, Instana can use OpenTelemetry tr...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Can Instana use OpenTelemetry trace data?​​​​​</td>\n",
       "      <td>Yes, Instana can ingest OpenTelemetry trace da...</td>\n",
       "      <td>Answer: Yes, Instana can use OpenTelemetry tr...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Can Maximo Visual Inspection run on prem?​​​</td>\n",
       "      <td>Answer: Yes. Maximo Visual Inspection is a cl...</td>\n",
       "      <td>Answer: Yes, Maximo Visual Inspection can be ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Explain me step by step how can I integrate da...</td>\n",
       "      <td>ive tried to follow the documentation but I do...</td>\n",
       "      <td>ive tried to do it but it is not working. Answ...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>What is watson knowledge catalog?</td>\n",
       "      <td>The IBM Watson Knowledge Catalog is a data cat...</td>\n",
       "      <td>Answer: The IBM Watson Knowledge Catalog is a...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>5</th>\n",
       "      <td>what is cloud pak for watson aiops</td>\n",
       "      <td>? Answer: Cloud Pak for Watson AIOps is an AI-...</td>\n",
       "      <td>? Answer: Cloud Pak for Watson AIOps is a plat...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>6</th>\n",
       "      <td>what is ibm?</td>\n",
       "      <td>Answer: IBM is an American multinational techn...</td>\n",
       "      <td>Answer: IBM is an American multinational techn...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>7</th>\n",
       "      <td>which cloudpak provides business process autom...</td>\n",
       "      <td>? Answer: IBM Cloud Pak for Automation</td>\n",
       "      <td>? : What is the difference between IBM Cloud P...</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>8</th>\n",
       "      <td>who is the CEO of IBM?</td>\n",
       "      <td>Answer: Arvind Krishna Answer: Arvind Krishna</td>\n",
       "      <td>.</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                         instruction  \\\n",
       "0         Can Instana use OpenTelemetry trace data?​   \n",
       "1     Can Instana use OpenTelemetry trace data?​​​​​   \n",
       "2       Can Maximo Visual Inspection run on prem?​​​   \n",
       "3  Explain me step by step how can I integrate da...   \n",
       "4                  What is watson knowledge catalog?   \n",
       "5                 what is cloud pak for watson aiops   \n",
       "6                                       what is ibm?   \n",
       "7  which cloudpak provides business process autom...   \n",
       "8                             who is the CEO of IBM?   \n",
       "\n",
       "                                     chosen_response  \\\n",
       "0  Yes, Instana can use OpenTelemetry trace data....   \n",
       "1  Yes, Instana can ingest OpenTelemetry trace da...   \n",
       "2   Answer: Yes. Maximo Visual Inspection is a cl...   \n",
       "3  ive tried to follow the documentation but I do...   \n",
       "4  The IBM Watson Knowledge Catalog is a data cat...   \n",
       "5  ? Answer: Cloud Pak for Watson AIOps is an AI-...   \n",
       "6  Answer: IBM is an American multinational techn...   \n",
       "7            ? Answer: IBM Cloud Pak for Automation    \n",
       "8     Answer: Arvind Krishna Answer: Arvind Krishna    \n",
       "\n",
       "                                   rejected_response  \n",
       "0   Answer: Yes, Instana can use OpenTelemetry tr...  \n",
       "1   Answer: Yes, Instana can use OpenTelemetry tr...  \n",
       "2   Answer: Yes, Maximo Visual Inspection can be ...  \n",
       "3  ive tried to do it but it is not working. Answ...  \n",
       "4   Answer: The IBM Watson Knowledge Catalog is a...  \n",
       "5  ? Answer: Cloud Pak for Watson AIOps is a plat...  \n",
       "6  Answer: IBM is an American multinational techn...  \n",
       "7  ? : What is the difference between IBM Cloud P...  \n",
       "8                                                  .  "
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "df = pd.read_csv('feedback.csv')\n",
    "df['tup'] = list(zip(df['answer'], df['feedback']))\n",
    "df_g = df.groupby('question')['tup'].apply(list).reset_index()\n",
    "\n",
    "df_g[\"sorted_tup\"] = df_g[\"tup\"].apply(lambda x :sorted(x,key=itemgetter(1)) )\n",
    "df_g[\"chosen\"] = df_g[\"sorted_tup\"].apply(lambda x: x[-1][0])\n",
    "df_g[\"chosen_score\"] = df_g[\"sorted_tup\"].apply(lambda x: x[-1][1])\n",
    "df_g[\"rejected\"] = df_g[\"sorted_tup\"].apply(lambda x: x[0][0])\n",
    "df_g[\"rejected_score\"] = df_g[\"sorted_tup\"].apply(lambda x: x[0][1])\n",
    "df_g = df_g.dropna()\n",
    "df_g.to_csv(\"treated_feedback.csv\")\n",
    "\n",
    "df_g = df_g[(df_g['chosen_score']>=4.0) & (df_g['rejected_score']<4.0)]\n",
    "#df_g.to_csv(\"treated_feedback.csv\")\n",
    "\n",
    "# build a dataset with chosen and rejected responses\n",
    "rows = []\n",
    "for record in df_g.itertuples(index=True, name='Pandas'):\n",
    "    if record is None or len(record) == 0:\n",
    "        continue\n",
    "    # build rows for rm training\n",
    "    rows.append({\n",
    "        \"instruction\": record.question,\n",
    "        \"chosen_response\": record.chosen,\n",
    "        \"rejected_response\": record.rejected\n",
    "    })\n",
    "\n",
    "# build dataset for training\n",
    "prepared_dataset = Dataset.from_list(rows)\n",
    "prepared_dataset.to_pandas()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Configuing the base model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.dense.weight', 'lm_head.layer_norm.weight', 'roberta.pooler.dense.bias', 'lm_head.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'lm_head.layer_norm.bias', 'lm_head.dense.bias']\n",
      "- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.out_proj.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/9 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "#To train Reward Model we need to choose a base model to fine-tune.\n",
    "rm_model_name = \"distilroberta-base\"\n",
    "\n",
    "rm_model = AutoModelForSequenceClassification.from_pretrained(rm_model_name, num_labels=1)\n",
    "rm_tokenizer = AutoTokenizer.from_pretrained(rm_model_name)\n",
    "\n",
    "if rm_tokenizer.pad_token is None:\n",
    "    rm_tokenizer.pad_token = rm_tokenizer.eos_token\n",
    "    rm_model.config.pad_token_id = rm_model.config.eos_token_id\n",
    "#This function combines instructions with chosen and rejected responses, creating two new strings.\n",
    "#These strings are tokenized, becoming input for a reward model that learns to distinguish between good and bad responses based on these examples.\n",
    "#The model will be optimized to assign higher values to preferred responses and lower values to rejected responses.\n",
    "def formatting_func(examples):\n",
    "    kwargs = {\"padding\": \"max_length\", \"truncation\": True, \"max_length\": 512, \"return_tensors\": \"pt\"}\n",
    "\n",
    "    # Prepend the prompt and a line break to the original_response and response-1 fields.\n",
    "    prompt_plus_chosen_response = examples[\"instruction\"] + \"\\n\" + examples[\"chosen_response\"]\n",
    "    prompt_plus_rejected_response = examples[\"instruction\"] + \"\\n\" + examples[\"rejected_response\"]\n",
    "\n",
    "    # Then tokenize these modified fields.\n",
    "    tokens_chosen = rm_tokenizer.encode_plus(prompt_plus_chosen_response, **kwargs)\n",
    "    tokens_rejected = rm_tokenizer.encode_plus(prompt_plus_rejected_response, **kwargs)\n",
    "\n",
    "    return {\n",
    "        \"input_ids_chosen\": tokens_chosen[\"input_ids\"][0], \"attention_mask_chosen\": tokens_chosen[\"attention_mask\"][0],\n",
    "        \"input_ids_rejected\": tokens_rejected[\"input_ids\"][0], \"attention_mask_rejected\": tokens_rejected[\"attention_mask\"][0]\n",
    "    }\n",
    "\n",
    "formatted_dataset = prepared_dataset.map(formatting_func)\n",
    "formatted_dataset = formatted_dataset.train_test_split()\n",
    "\n",
    "rm_training_args = TrainingArguments(\n",
    "    output_dir=\"./reward_model\",\n",
    "    per_device_train_batch_size=16,\n",
    "    evaluation_strategy=\"steps\",\n",
    "    logging_steps=200,\n",
    "    num_train_epochs = 1,\n",
    "\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Training the reward model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "rm_trainer = RewardTrainer(\n",
    "    model=rm_model,\n",
    "    args=rm_training_args,\n",
    "    tokenizer=rm_tokenizer,\n",
    "    train_dataset=formatted_dataset[\"train\"],\n",
    "    eval_dataset=formatted_dataset[\"test\"],\n",
    ")\n",
    "\n",
    "rm_trainer.train() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Save the trained reward model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "rm_trainer.save_model() "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Conifguring the model to be finetuned using RL"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "config = PPOConfig(\n",
    "    model_name=\"lvwerra/gpt2-imdb\",\n",
    "    learning_rate=1.41e-5,\n",
    "    #log_with=\"wandb\",\n",
    ")\n",
    "\n",
    "sent_kwargs = {\"return_all_scores\": True, \"function_to_apply\": \"none\", \"batch_size\": 16}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetch data "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def build_dataset(config, dataset_name=\"gsm8k\", input_min_text_length=2, input_max_text_length=200):\n",
    "    \"\"\"\n",
    "    Build dataset for training. This builds the dataset from `load_dataset`, one should\n",
    "    customize this function to train the model on its own dataset.\n",
    "\n",
    "    Args:\n",
    "        dataset_name (`str`):\n",
    "            The name of the dataset to be loaded.\n",
    "\n",
    "    Returns:\n",
    "        dataloader (`torch.utils.data.DataLoader`):\n",
    "            The dataloader for the dataset.\n",
    "    \"\"\"\n",
    "    tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
    "    tokenizer.pad_token = tokenizer.eos_token\n",
    "    # load imdb with datasets\n",
    "\n",
    "    #ds = load_dataset(dataset_name,'main', split=\"train[:10%]\")\n",
    "\n",
    "    df = pd.read_csv(\"feedback.csv\")\n",
    "    ds = Dataset.from_pandas(df)\n",
    "\n",
    "    ds = ds.rename_columns({\"question\": \"review\"})\n",
    "    #ds = ds.filter(lambda x: len(x[\"review\"]) > 20, batched=False)\n",
    "\n",
    "    input_size = LengthSampler(input_min_text_length, input_max_text_length)\n",
    "\n",
    "    def tokenize(sample):\n",
    "        sample[\"input_ids\"] = tokenizer.encode(sample[\"review\"])[: input_size()]\n",
    "        sample[\"query\"] = tokenizer.decode(sample[\"input_ids\"])\n",
    "        return sample\n",
    "\n",
    "    ds = ds.map(tokenize, batched=False)\n",
    "    ds.set_format(type=\"torch\")\n",
    "    return ds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Map:   0%|          | 0/268 [00:00<?, ? examples/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "dataset = build_dataset(config)\n",
    "\n",
    "\n",
    "def collator(data):\n",
    "    return dict((key, [d[key] for d in data]) for key in data[0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'review': ['Can Maximo Visual Inspection run on prem?\\u200b\\u200b\\u200b',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'Can Instana use OpenTelemetry trace data?\\u200b\\u200b\\u200b\\u200b\\u200b',\n",
       "  'format it on a table',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?'],\n",
       " 'answer': [' Answer: Yes, Maximo Visual Inspection can be run on premise. ',\n",
       "  ' Answer: The IBM Watson Knowledge Catalog is a data catalog that serves as a single version of the truth for different users including data engineers, business analysts, data analysts, data scientists, and data citizens. Users can gain access to data they can trust, govern, curate, share and manage within an organization. Beyond that, you can add data policies and rules around your data to ensure your information doesnt get into the wrong hands and is compliant. Take control of your data by protecting your sensitive information and tracing the lineage of your data. You should always be certain of where the data came from and how it is being used in your organization. The data governance, data quality, and policy management capabilities within the IBM Watson Knowledge Catalog will help make your data business-ready. The journey to AI starts here, with clean and trusted data. Is your data business-ready? Keep reading to find tips on how to leverage the IBM Watson Knowledge Catalog to get trusted data. One of the',\n",
       "  ' Answer: The IBM Watson Knowledge Catalog is a data catalog that serves as a single version of the truth for different users including data engineers, business analysts, data analysts, data scientists, and data citizens. Users can gain access to data they can trust, govern, curate, share and manage within an organization. Beyond that, you can add data policies and rules around your data to ensure your information doesnt get into the wrong hands and is compliant. Take control of your data by protecting your sensitive information and tracing the lineage of your data. You should always be certain of where the data came from and how it is being used in your organization. The data governance, data quality, and policy management capabilities within the IBM Watson Knowledge Catalog will help make your data business-ready. The journey to AI starts here, with clean and trusted data. Is your data business-ready? Keep reading to find tips on how to leverage the IBM Watson Knowledge Catalog to get trusted data. One of the',\n",
       "  ' Answer: Yes, Instana can ingest OpenTelemetry trace data. Instana supports OpenTelemetry trace data from any OpenTelemetry-compatible application. Instana can ingest trace data from any application that is instrumented with OpenTelemetry. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is a standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating telemetry data from applications. OpenTelemetry is an open-source standard for instrumenting, collecting, and aggregating',\n",
       "  'with the following columns. The first column is the id, the second column is the name, the third column is the phone number, the fourth column is the lazy boolean value, and the fifth column is the married boolean value. Answer: The following SQL statement creates a table with the required columns and inserts the sample data from Listing 3 into the table. The SQL statement uses the JSON_TABLE function to extract the required information from the JSON object. The JSON_TABLE function is used in the FROM clause of the SQL statement. The SQL statement uses a JSON path expression to specify the columns of the table. The JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the comma operator. The SQL/JSON path expression is a sequence of SQL/JSON path expressions separated by the',\n",
       "  ' Answer: Watson Orchestrate is a cloud-native, open-source, and open API orchestration platform that accelerates the development and deployment of cloud-native applications and workflows. It is a modern orchestration platform that is built on Kubernetes and is designed to manage and orchestrate complex workflows across hybrid cloud environments. It is a cloud-native orchestration platform that is designed to manage and orchestrate complex workflows across hybrid cloud environments. IBM Integration, IBM Business Automation, and IBM IT Automation products are available on AWS. IBM Integration products on AWS: IBM Integration Bus, IBM Integration Bus for z/OS, IBM Integration Bus for z/OS on Linux, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift on IBM Cloud, IBM Integration Bus for z/OS on Linux on Red Hat OpenShift on IBM Cloud on Red Hat OpenShift, IBM Integration Bus',\n",
       "  ' Answer: Watson Orchestrate is a cloud-native, open-source orchestration platform that enables you to automate and manage complex workflows across hybrid cloud environments. It is a fully managed service that can be deployed in AWS, Azure, and GCP. It is built on Kubernetes and can be used with any cloud-native application. It can be used to orchestrate workflows across multiple cloud-native applications, including IBM Cloud Pak for Integration, IBM Cloud Pak for Data, IBM Cloud Pak for Applications, and IBM Cloud Pak for Multicloud Management. It can also be used to orchestrate workflows across multiple cloud-native applications, including IBM Cloud Pak for Integration, IBM Cloud Pak for Data, IBM Cloud Pak for Applications, and IBM Cloud Pak for Multicloud Management. Watson Orchestrate is a cloud-native, open-source orchestration platform that enables you to automate and manage complex workflows across hybrid cloud environments. It is a fully managed service that can be deployed in',\n",
       "  ' : What is a data fabric? : What is a data lake? ',\n",
       "  ' Answer: Watson Orchestrate is a cloud-based service that helps you automate your data science and machine learning workflows. ',\n",
       "  ' Answer: Watson Orchestrate is a service that provides a unified orchestration platform for automating the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is built on Kubernetes and is designed to work with any cloud provider. It is built on Kubernetes and is designed to work with any cloud provider. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is a cloud-native, open-source, multi-cloud orchestration platform that can be used to automate the deployment, management, and governance of cloud-native applications. It is'],\n",
       " 'feedback': tensor([2., 4., 3., 4., 1., 2., 2., 1., 3., 3.]),\n",
       " 'input_ids': [tensor([ 6090,  5436, 25147, 15612, 47115,  1057,   319,  4199,    30, 39009,\n",
       "           9525]),\n",
       "  tensor([ 2061,   318,   266, 13506,  3725, 18388,    30]),\n",
       "  tensor([ 2061,   318,   266, 13506,  3725, 18388,    30]),\n",
       "  tensor([ 6090,  2262,  2271,   779,  4946, 31709, 41935, 12854,  1366,    30,\n",
       "          39009, 39009,  9525]),\n",
       "  tensor([18982,   340,   319,   257,  3084]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30]),\n",
       "  tensor([ 2061, 14959, 30369, 23104,   318,   329,    30]),\n",
       "  tensor([ 2061, 14959, 30369, 23104,   318,   329,    30]),\n",
       "  tensor([10919,   389,   262,  5400,  1022, 14959, 30369, 23104,   290,   584,\n",
       "          19764, 22771,  4899,    30])],\n",
       " 'query': ['Can Maximo Visual Inspection run on prem?\\u200b\\u200b\\u200b',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'What is watson knowledge catalog?',\n",
       "  'Can Instana use OpenTelemetry trace data?\\u200b\\u200b\\u200b\\u200b\\u200b',\n",
       "  'format it on a table',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'What Watson Orchestrate is for?',\n",
       "  'what are the differences between Watson Orchestrate and other IBM automation tools?']}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "dataset[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
    "ref_model = AutoModelForCausalLMWithValueHead.from_pretrained(config.model_name)\n",
    "tokenizer = AutoTokenizer.from_pretrained(config.model_name)\n",
    "\n",
    "from transformers import T5Tokenizer, T5ForConditionalGeneration\n",
    "# model = T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xxl\", device_map=\"auto\")\n",
    "# ref_model =  T5ForConditionalGeneration.from_pretrained(\"google/flan-t5-xxl\", device_map=\"auto\")\n",
    "# tokenizer = T5Tokenizer.from_pretrained(\"google/flan-t5-xxl\")\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "ppo_trainer = PPOTrainer(config, model, ref_model, tokenizer, dataset=dataset, data_collator=collator)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = ppo_trainer.accelerator.device\n",
    "if ppo_trainer.accelerator.num_processes == 1:\n",
    "    device = 0 if torch.cuda.is_available() else \"cpu\"  # to avoid a `pipeline` bug\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "###############"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "rm_model_trained = AutoModelForSequenceClassification.from_pretrained(\"./reward_model\")\n",
    "rm_tokenizer_trained = AutoTokenizer.from_pretrained(\"./reward_model\",padding=True,truncation=True)\n",
    "\n",
    "if rm_tokenizer_trained.pad_token is None:\n",
    "    rm_tokenizer_trained.pad_token = rm_tokenizer_trained.eos_token\n",
    "    rm_model_trained.config.pad_token_id = rm_model_trained.config.eos_token_id\n",
    "    \n",
    "    \n",
    "# text = [\"this is one sentence\", \"this is another sentence\"]\n",
    "# encoding = tokenizer(text, return_tensors=\"pt\")\n",
    "\n",
    "# # forward pass\n",
    "# outputs = model(**encoding)\n",
    "# predictions = outputs.logits.argmax(-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "text = [\"this is really bad sentence\",\"this is really goof  ec ece w\"]\n",
    "rm_tokenizer_trained = AutoTokenizer.from_pretrained(\"./reward_model\")\n",
    "\n",
    "encoding = rm_tokenizer_trained(text, return_tensors=\"pt\",padding=True,truncation=True)\n",
    "\n",
    "# # forward pass\n",
    "# outputs = rm_model_trained(**encoding)\n",
    "# predictions = outputs.logits.argmax(-1)\n",
    "# predictions\n",
    "# outputs.logits"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [],
   "source": [
    "# text = \"this movie was really bad!!\"\n",
    "# sentiment_pipe(text, **sent_kwargs)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# RLHF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen_kwargs = {\"min_length\": -1, \"top_k\": 0.0, \"top_p\": 1.0, \"do_sample\": True, \"pad_token_id\": tokenizer.eos_token_id}\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "0it [00:00, ?it/s]/tmp/ipykernel_3631327/2175772355.py:33: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).\n",
      "  rewards = [torch.tensor(i) for i in outputs.logits]\n",
      "1it [02:03, 123.13s/it]\n"
     ]
    }
   ],
   "source": [
    "# The number of new tokens for each response is drawn per query from this sampler.\n",
    "output_min_length = 4\n",
    "output_max_length = 16\n",
    "output_length_sampler = LengthSampler(output_min_length, output_max_length)\n",
    "\n",
    "# Sampling settings shared by every generate() call in the loop below.\n",
    "# max_new_tokens is passed per call instead of being written into this dict,\n",
    "# so the dict is not mutated while the loop runs.\n",
    "generation_kwargs = {\n",
    "    \"min_length\": -1,\n",
    "    \"top_k\": 0.0,\n",
    "    \"top_p\": 1.0,\n",
    "    \"do_sample\": True,\n",
    "    \"pad_token_id\": tokenizer.eos_token_id,\n",
    "}\n",
    "\n",
    "# tqdm wraps the dataloader itself (not an enumerate object) so it can read the\n",
    "# number of batches and show real progress instead of a bare iteration counter.\n",
    "for batch in tqdm(ppo_trainer.dataloader):\n",
    "    query_tensors = batch[\"input_ids\"]\n",
    "\n",
    "    #### Get response from gpt2\n",
    "    response_tensors = []\n",
    "    for query in query_tensors:\n",
    "        gen_len = output_length_sampler()\n",
    "        response = ppo_trainer.generate(query, max_new_tokens=gen_len, **generation_kwargs)\n",
    "        # The generated sequence is prompt + continuation. Slice at the prompt\n",
    "        # length rather than keeping the last gen_len tokens: sampling may stop\n",
    "        # early at EOS, and a fixed-size tail would then include prompt tokens.\n",
    "        response_tensors.append(response.squeeze()[query.shape[-1]:])\n",
    "    batch[\"response\"] = [tokenizer.decode(r.squeeze()) for r in response_tensors]\n",
    "\n",
    "    #### Compute reward-model score for each query + response text\n",
    "    text = [q + r for q, r in zip(batch[\"query\"], batch[\"response\"])]\n",
    "    # padding=True pads to the longest text in the batch instead of the tokenizer maximum.\n",
    "    encoding = rm_tokenizer_trained(text, return_tensors=\"pt\", padding=True, truncation=True)\n",
    "    # The reward model is only queried here, so no autograd graph is needed.\n",
    "    with torch.no_grad():\n",
    "        outputs = rm_model_trained(**encoding)\n",
    "    # One reward tensor per sample, taken from the raw logits. clone().detach()\n",
    "    # copies the values without the UserWarning raised by torch.tensor(tensor).\n",
    "    rewards = [logit.clone().detach() for logit in outputs.logits]\n",
    "\n",
    "    #### Run PPO step\n",
    "    stats = ppo_trainer.step(query_tensors, response_tensors, rewards)\n",
    "    ppo_trainer.log_stats(stats, batch, rewards)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Scratch cell: evaluates to 4.\n",
    "2 + 2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "ename": "NameError",
     "evalue": "name 'sentiment_pipe' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[23], line 29\u001b[0m\n\u001b[1;32m     27\u001b[0m \u001b[38;5;66;03m#### sentiment analysis of query/response pairs before/after\u001b[39;00m\n\u001b[1;32m     28\u001b[0m texts \u001b[38;5;241m=\u001b[39m [q \u001b[38;5;241m+\u001b[39m r \u001b[38;5;28;01mfor\u001b[39;00m q, r \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery\u001b[39m\u001b[38;5;124m\"\u001b[39m], game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mresponse (before)\u001b[39m\u001b[38;5;124m\"\u001b[39m])]\n\u001b[0;32m---> 29\u001b[0m game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrewards (before)\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m [output[\u001b[38;5;241m1\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscore\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m \u001b[43msentiment_pipe\u001b[49m(texts, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msent_kwargs)]\n\u001b[1;32m     31\u001b[0m texts \u001b[38;5;241m=\u001b[39m [q \u001b[38;5;241m+\u001b[39m r \u001b[38;5;28;01mfor\u001b[39;00m q, r \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mquery\u001b[39m\u001b[38;5;124m\"\u001b[39m], game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mresponse (after)\u001b[39m\u001b[38;5;124m\"\u001b[39m])]\n\u001b[1;32m     32\u001b[0m game_data[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrewards (after)\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;241m=\u001b[39m [output[\u001b[38;5;241m1\u001b[39m][\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mscore\u001b[39m\u001b[38;5;124m\"\u001b[39m] \u001b[38;5;28;01mfor\u001b[39;00m output \u001b[38;5;129;01min\u001b[39;00m sentiment_pipe(texts, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msent_kwargs)]\n",
      "\u001b[0;31mNameError\u001b[0m: name 'sentiment_pipe' is not defined"
     ]
    }
   ],
   "source": [
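    "# NOTE: this evaluation cell is disabled. It calls `sentiment_pipe`, which is not defined in this notebook; the saved NameError output comes from a run of the uncommented code.\n",
    "# #### get a batch from the dataset\n",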
    "# bs = 16\n",
    "# game_data = dict()\n",
    "# dataset.set_format(\"pandas\")\n",
    "# df_batch = dataset[:].sample(bs)\n",
    "# game_data[\"query\"] = df_batch[\"query\"].tolist()\n",
    "# query_tensors = df_batch[\"input_ids\"].tolist()\n",
    "\n",
    "# response_tensors_ref, response_tensors = [], []\n",
    "\n",
    "# #### get response from gpt2 and gpt2_ref\n",
    "# for i in range(bs):\n",
    "#     gen_len = output_length_sampler()\n",
    "#     output = ref_model.generate(\n",
    "#         torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
    "#     ).squeeze()[-gen_len:]\n",
    "#     response_tensors_ref.append(output)\n",
    "#     output = model.generate(\n",
    "#         torch.tensor(query_tensors[i]).unsqueeze(dim=0).to(device), max_new_tokens=gen_len, **gen_kwargs\n",
    "#     ).squeeze()[-gen_len:]\n",
    "#     response_tensors.append(output)\n",
    "\n",
    "# #### decode responses\n",
    "# game_data[\"response (before)\"] = [tokenizer.decode(response_tensors_ref[i]) for i in range(bs)]\n",
    "# game_data[\"response (after)\"] = [tokenizer.decode(response_tensors[i]) for i in range(bs)]\n",
    "\n",
    "# #### sentiment analysis of query/response pairs before/after\n",
    "# texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (before)\"])]\n",
    "# game_data[\"rewards (before)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
    "\n",
    "# texts = [q + r for q, r in zip(game_data[\"query\"], game_data[\"response (after)\"])]\n",
    "# game_data[\"rewards (after)\"] = [output[1][\"score\"] for output in sentiment_pipe(texts, **sent_kwargs)]\n",
    "\n",
    "# # store results in a dataframe\n",
    "# df_results = pd.DataFrame(game_data)\n",
    "# df_results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "mean:\n"
     ]
    },
    {
     "ename": "NameError",
     "evalue": "name 'df_results' is not defined",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
      "Cell \u001b[0;32mIn[24], line 2\u001b[0m\n\u001b[1;32m      1\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmean:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 2\u001b[0m display(\u001b[43mdf_results\u001b[49m[[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrewards (before)\u001b[39m\u001b[38;5;124m\"\u001b[39m, \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mrewards (after)\u001b[39m\u001b[38;5;124m\"\u001b[39m]]\u001b[38;5;241m.\u001b[39mmean())\n\u001b[1;32m      3\u001b[0m \u001b[38;5;28mprint\u001b[39m()\n\u001b[1;32m      4\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mmedian:\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n",
      "\u001b[0;31mNameError\u001b[0m: name 'df_results' is not defined"
     ]
    }
   ],
   "source": [
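    "# NOTE: disabled because `df_results` is only built by the commented-out evaluation cell above (hence the saved NameError output).\n",
    "# print(\"mean:\")\n",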
    "# display(df_results[[\"rewards (before)\", \"rewards (after)\"]].mean())\n",
    "# print()\n",
    "# print(\"median:\")\n",
    "# display(df_results[[\"rewards (before)\", \"rewards (after)\"]].median())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ak_rlhf",
   "language": "python",
   "name": "ak_rlhf"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
